{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0XQ6NsIuDtgr"
      },
      "source": [
        "# Illustrated: Self-Attention\n",
        "Step-by-step guide to self-attention with illustrations and code\n",
        "\n",
        "[Medium article](https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a)\n",
        "\n",
        "[Article author](https://towardsdatascience.com/@remykarem)\n",
        "\n",
        "> Colab made by: [Manuel Romero](https://twitter.com/mrm8488)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U76qWlrbOmx7"
      },
      "source": [
        "![Illustrated self-attention animation](https://miro.medium.com/max/1973/1*_92bnsMJy8Bl539G4v93yg.gif)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wOkXKd60Q_Iu"
      },
      "source": [
        "What do *BERT, RoBERTa, ALBERT, SpanBERT, DistilBERT, SesameBERT, SemBERT, MobileBERT, TinyBERT and CamemBERT* all have in common? And I’m not looking for the answer “BERT” 🤭.\n",
        "\n",
        "Answer: **self-attention** 🤗. We are not only talking about architectures bearing the name “BERT”, but more correctly about **Transformer-based architectures**. Transformer-based architectures, which are primarily used for modelling language understanding tasks, eschew the recurrence used in recurrent neural networks (RNNs) and instead rely entirely on self-attention mechanisms to draw global dependencies between inputs and outputs. But what’s the math behind this?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yozzTBjBRbAA"
      },
      "source": [
        "The main goal of this notebook is to walk you through the mathematical operations involved in a self-attention module."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "atUYzU3TSD9z"
      },
      "source": [
        "### Step 0: What is self-attention?\n",
        "\n",
        "If you’re wondering whether self-attention is similar to attention, the answer is yes! They fundamentally share the same concept and many common mathematical operations.\n",
        "\n",
        "A self-attention module takes in n inputs and returns n outputs. What happens in this module? In layman’s terms, the self-attention mechanism allows the inputs to interact with each other (“self”) and find out which ones they should pay more attention to (“attention”). The outputs are aggregates of these interactions and attention scores."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SDMmHAaSTE6P"
      },
      "source": [
        "In what follows, we explain and implement these steps:\n",
        "\n",
        "1. Prepare inputs\n",
        "2. Initialise weights\n",
        "3. Derive key, query and value\n",
        "4. Calculate attention scores for Input 1\n",
        "5. Calculate softmax\n",
        "6. Multiply scores with values\n",
        "7. Sum weighted values to get Output 1\n",
        "8. Repeat steps 4–7 for Input 2 & Input 3"
      ]
    },
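    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a compact reference, the steps listed above can be sketched in matrix form. The symbols $X$ (inputs), $W^Q$, $W^K$, $W^V$ (weights) and $Q$, $K$, $V$ are notation assumed here rather than taken from the article, and the score function is the unscaled dot product used in this tutorial:\n",
        "\n",
        "$$Q = XW^Q, \\quad K = XW^K, \\quad V = XW^V$$\n",
        "\n",
        "$$\\mathrm{Output} = \\mathrm{softmax}\\left(QK^\\top\\right)V$$\n",
        "\n",
        "The softmax is applied to each row of $QK^\\top$, so row $i$ of the output is a weighted sum of the rows of $V$, with weights given by how strongly query $i$ matches each key."
      ]
    },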
    {
      "cell_type": "code",
      "metadata": {
        "id": "u1UxPJlHBVmS"
      },
      "source": [
        "import torch"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ENdzUZqSBsiB"
      },
      "source": [
        "### Step 1: Prepare inputs\n",
        "\n",
        "For this tutorial, for the sake of simplicity, we start with 3 inputs, each with dimension 4.\n",
        "\n",
        "![Step 1 illustration: prepare inputs](https://miro.medium.com/max/1973/1*hmvdDXrxhJsGhOQClQdkBA.png)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jKYrJsljBhnv",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "7b865905-2151-4a6a-a899-5439aa429af4"
      },
      "source": [
        "x = [\n",
        "  [1, 0, 1, 0], # Input 1\n",
        "  [0, 2, 0, 2], # Input 2\n",
        "  [1, 1, 1, 1]  # Input 3\n",
        " ]\n",
        "x = torch.tensor(x, dtype=torch.float32)\n",
        "x"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[1., 0., 1., 0.],\n",
              "        [0., 2., 0., 2.],\n",
              "        [1., 1., 1., 1.]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 2
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DZ96EoE1Bvat"
      },
      "source": [
        "### Step 2: Initialise weights\n",
        "\n",
        "Every input must have three representations (see diagram below). These representations are called **key** (orange), **query** (red), and **value** (purple). For this example, let’s say we want these representations to have a dimension of 3. Because every input has a dimension of 4, each set of weights must have a shape of 4×3.\n",
        "\n",
        "![Step 2 illustration: initialise weights](https://miro.medium.com/max/1975/1*VPvXYMGjv0kRuoYqgFvCag.gif)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jUTNr15JBkSG",
        "outputId": "baa4c379-6174-4990-8cd2-51191e904550",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 284
        }
      },
      "source": [
        "w_key = [\n",
        "  [0, 0, 1],\n",
        "  [1, 1, 0],\n",
        "  [0, 1, 0],\n",
        "  [1, 1, 0]\n",
        "]\n",
        "w_query = [\n",
        "  [1, 0, 1],\n",
        "  [1, 0, 0],\n",
        "  [0, 0, 1],\n",
        "  [0, 1, 1]\n",
        "]\n",
        "w_value = [\n",
        "  [0, 2, 0],\n",
        "  [0, 3, 0],\n",
        "  [1, 0, 3],\n",
        "  [1, 1, 0]\n",
        "]\n",
        "w_key = torch.tensor(w_key, dtype=torch.float32)\n",
        "w_query = torch.tensor(w_query, dtype=torch.float32)\n",
        "w_value = torch.tensor(w_value, dtype=torch.float32)\n",
        "\n",
        "print(\"Weights for key: \\n\", w_key)\n",
        "print(\"Weights for query: \\n\", w_query)\n",
        "print(\"Weights for value: \\n\", w_value)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Weights for key: \n",
            " tensor([[0., 0., 1.],\n",
            "        [1., 1., 0.],\n",
            "        [0., 1., 0.],\n",
            "        [1., 1., 0.]])\n",
            "Weights for query: \n",
            " tensor([[1., 0., 1.],\n",
            "        [1., 0., 0.],\n",
            "        [0., 0., 1.],\n",
            "        [0., 1., 1.]])\n",
            "Weights for value: \n",
            " tensor([[0., 2., 0.],\n",
            "        [0., 3., 0.],\n",
            "        [1., 0., 3.],\n",
            "        [1., 1., 0.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8pr9XZF9X_Ed"
      },
      "source": [
        "Note: *In a neural network setting, these weights are usually small numbers, initialised randomly, for example by sampling from a Gaussian distribution or by using the Xavier or Kaiming initialisation schemes.*"
      ]
    },
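    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of what random initialisation could look like in PyTorch, assuming the same 4×3 weight shape as above. The names `w_key_rand`, `w_query_rand` and `w_value_rand`, the seed and the standard deviation of 0.02 are illustrative choices, not values from the article. The rest of this tutorial keeps the hand-picked integer weights so the arithmetic stays easy to follow.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "torch.manual_seed(0)  # fixed seed so the sketch is reproducible (arbitrary choice)\n",
        "\n",
        "d_in, d_out = 4, 3  # input dimension and representation dimension used in this tutorial\n",
        "\n",
        "# Hypothetical names, chosen so the hand-picked weights above are not overwritten\n",
        "w_key_rand = torch.empty(d_in, d_out)\n",
        "w_query_rand = torch.empty(d_in, d_out)\n",
        "w_value_rand = torch.empty(d_in, d_out)\n",
        "\n",
        "torch.nn.init.normal_(w_key_rand, mean=0.0, std=0.02)  # small Gaussian values\n",
        "torch.nn.init.xavier_uniform_(w_query_rand)            # Xavier (Glorot) uniform\n",
        "torch.nn.init.kaiming_uniform_(w_value_rand)           # Kaiming (He) uniform\n",
        "\n",
        "print(w_key_rand.shape, w_query_rand.shape, w_value_rand.shape)\n",
        "```\n",
        "\n",
        "All three tensors have shape 4×3, so any of them could stand in for the fixed weights used in the following steps."
      ]
    },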
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UxGT5awVB1Xw"
      },
      "source": [
        "### Step 3: Derive key, query and value\n",
        "\n",
        "Now that we have the three sets of weights, let’s actually obtain the **key**, **query** and **value** representations for every input."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VQwhDIi7aGXp"
      },
      "source": [
        "Obtaining the keys:\n",
        "```\n",
        "               [0, 0, 1]\n",
        "[1, 0, 1, 0]   [1, 1, 0]   [0, 1, 1]\n",
        "[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]\n",
        "[1, 1, 1, 1]   [1, 1, 0]   [2, 3, 1]\n",
        "```\n",
        "![Illustration: deriving the keys](https://miro.medium.com/max/1975/1*dr6NIaTfTxEWzxB2rc0JWg.gif)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qi0EblXTamFz"
      },
      "source": [
        "Obtaining the values:\n",
        "```\n",
        "               [0, 2, 0]\n",
        "[1, 0, 1, 0]   [0, 3, 0]   [1, 2, 3]\n",
        "[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]\n",
        "[1, 1, 1, 1]   [1, 1, 0]   [2, 6, 3]\n",
        "```\n",
        "![Illustration: deriving the values](https://miro.medium.com/max/1975/1*5kqW7yEwvcC0tjDOW3Ia-A.gif)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GTp2izu1bLNq"
      },
      "source": [
        "Obtaining the queries:\n",
        "```\n",
        "               [1, 0, 1]\n",
        "[1, 0, 1, 0]   [1, 0, 0]   [1, 0, 2]\n",
        "[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]\n",
        "[1, 1, 1, 1]   [0, 1, 1]   [2, 1, 3]\n",
        "```\n",
        "![Illustration: deriving the queries](https://miro.medium.com/max/1975/1*wO_UqfkWkv3WmGQVHvrMJw.gif)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qegb9M0KbnRK"
      },
      "source": [
        "Note: *In practice, a bias vector may be added to the product of the matrix multiplication.*"
      ]
    },
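    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the bias remark concrete, here is a hedged sketch that expresses the key projection as a `torch.nn.Linear` layer, which computes `x @ W.T + b`. The layer name `key_layer` is hypothetical, and the bias is set to zero on purpose so that the result reproduces the plain matrix product shown above. A trained model would learn both the weights and the bias.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "# Inputs and key weights copied from this tutorial so the block runs on its own\n",
        "x = torch.tensor([[1., 0., 1., 0.],\n",
        "                  [0., 2., 0., 2.],\n",
        "                  [1., 1., 1., 1.]])\n",
        "w_key = torch.tensor([[0., 0., 1.],\n",
        "                      [1., 1., 0.],\n",
        "                      [0., 1., 0.],\n",
        "                      [1., 1., 0.]])\n",
        "\n",
        "# Hypothetical layer: maps 4-dimensional inputs to 3-dimensional keys, with a bias term\n",
        "key_layer = torch.nn.Linear(in_features=4, out_features=3, bias=True)\n",
        "with torch.no_grad():\n",
        "    key_layer.weight.copy_(w_key.T)  # nn.Linear stores weights as (out_features, in_features)\n",
        "    key_layer.bias.zero_()           # zero bias, so the output equals x @ w_key\n",
        "\n",
        "keys_linear = key_layer(x).detach()\n",
        "print(keys_linear)\n",
        "```\n",
        "\n",
        "With a non-zero bias, the same vector would simply be added to every row of the keys."
      ]
    },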
    {
      "cell_type": "code",
      "metadata": {
        "id": "rv2NXynOB7oG",
        "outputId": "a2656b52-4b1d-4726-9d42-522f941b3126",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 230
        }
      },
      "source": [
        "keys = x @ w_key\n",
        "querys = x @ w_query\n",
        "values = x @ w_value\n",
        "\n",
        "print(\"Keys: \\n\", keys)\n",
        "# tensor([[0., 1., 1.],\n",
        "#         [4., 4., 0.],\n",
        "#         [2., 3., 1.]])\n",
        "\n",
        "print(\"Querys: \\n\", querys)\n",
        "# tensor([[1., 0., 2.],\n",
        "#         [2., 2., 2.],\n",
        "#         [2., 1., 3.]])\n",
        "print(\"Values: \\n\", values)\n",
        "# tensor([[1., 2., 3.],\n",
        "#         [2., 8., 0.],\n",
        "#         [2., 6., 3.]])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Keys: \n",
            " tensor([[0., 1., 1.],\n",
            "        [4., 4., 0.],\n",
            "        [2., 3., 1.]])\n",
            "Querys: \n",
            " tensor([[1., 0., 2.],\n",
            "        [2., 2., 2.],\n",
            "        [2., 1., 3.]])\n",
            "Values: \n",
            " tensor([[1., 2., 3.],\n",
            "        [2., 8., 0.],\n",
            "        [2., 6., 3.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3pmf0OQhCnD8"
      },
      "source": [
        "### Step 4: Calculate attention scores\n",
        "![Step 4 illustration: calculate attention scores](https://miro.medium.com/max/1973/1*u27nhUppoWYIGkRDmYFN2A.gif)\n",
        "\n",
        "To obtain **attention scores**, we start by taking the dot product of Input 1’s **query** (red) with **all keys** (orange), including its own. Since there are 3 key representations (because we have 3 inputs), we obtain 3 attention scores (blue).\n",
        "\n",
        "```\n",
        "            [0, 4, 2]\n",
        "[1, 0, 2] x [1, 4, 3] = [2, 4, 4]\n",
        "            [1, 0, 1]\n",
        "```\n",
        "Notice that we only use the query from Input 1. Later we’ll repeat this same step for the other queries.\n",
        "\n",
        "Note: *The above operation is known as dot product attention, one of several score functions. Other score functions include scaled dot product and additive/concat.*"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6GDhKEl0Cokw",
        "outputId": "c91356df-202c-4816-e98d-eefd1e1031d3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        }
      },
      "source": [
        "attn_scores = querys @ keys.T\n",
        "print(attn_scores)\n",
        "\n",
        "# tensor([[ 2.,  4.,  4.],  # attention scores from Query 1\n",
        "#         [ 4., 16., 12.],  # attention scores from Query 2\n",
        "#         [ 4., 12., 10.]]) # attention scores from Query 3"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[ 2.,  4.,  4.],\n",
            "        [ 4., 16., 12.],\n",
            "        [ 4., 12., 10.]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
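    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The note above mentions scaled dot product as another score function. As a sketch, assuming the common convention of dividing the scores by the square root of the key dimension (here 3), it could look like this. The tensors are copied from Step 3 so the block runs on its own, and the name `scaled_scores` is introduced only for illustration. The rest of the tutorial continues with the unscaled scores.\n",
        "\n",
        "```python\n",
        "import math\n",
        "import torch\n",
        "\n",
        "# Queries and keys as derived in Step 3\n",
        "querys = torch.tensor([[1., 0., 2.],\n",
        "                       [2., 2., 2.],\n",
        "                       [2., 1., 3.]])\n",
        "keys = torch.tensor([[0., 1., 1.],\n",
        "                     [4., 4., 0.],\n",
        "                     [2., 3., 1.]])\n",
        "\n",
        "d_k = keys.shape[-1]  # dimension of each key vector (3 in this tutorial)\n",
        "scaled_scores = (querys @ keys.T) / math.sqrt(d_k)\n",
        "print(scaled_scores)\n",
        "```\n",
        "\n",
        "Scaling shrinks large scores, which keeps the softmax in the next step from becoming too peaked."
      ]
    },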
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bO3NmnbvCxpX"
      },
      "source": [
        "### Step 5: Calculate softmax\n",
        "![Step 5 illustration: calculate softmax](https://miro.medium.com/max/1973/1*jf__2D8RNCzefwS0TP1Kyg.gif)\n",
        "\n",
        "Take the **softmax** across these **attention scores** (blue). The exact result is about `[0.06, 0.47, 0.47]`; for readability, this tutorial rounds it coarsely:\n",
        "```\n",
        "softmax([2, 4, 4]) ≈ [0.0, 0.5, 0.5]\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PDNzdZHVC1ys",
        "outputId": "c528a7be-5c26-46a9-8fdb-1f2b029b6b93",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 124
        }
      },
      "source": [
        "from torch.nn.functional import softmax\n",
        "\n",
        "attn_scores_softmax = softmax(attn_scores, dim=-1)\n",
        "print(attn_scores_softmax)\n",
        "# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],\n",
        "#         [6.0337e-06, 9.8201e-01, 1.7986e-02],\n",
        "#         [2.9539e-04, 8.8054e-01, 1.1917e-01]])\n",
        "\n",
        "# For readability, approximate the above as follows\n",
        "attn_scores_softmax = [\n",
        "  [0.0, 0.5, 0.5],\n",
        "  [0.0, 1.0, 0.0],\n",
        "  [0.0, 0.9, 0.1]\n",
        "]\n",
        "attn_scores_softmax = torch.tensor(attn_scores_softmax)\n",
        "print(attn_scores_softmax)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],\n",
            "        [6.0337e-06, 9.8201e-01, 1.7986e-02],\n",
            "        [2.9539e-04, 8.8054e-01, 1.1917e-01]])\n",
            "tensor([[0.0000, 0.5000, 0.5000],\n",
            "        [0.0000, 1.0000, 0.0000],\n",
            "        [0.0000, 0.9000, 0.1000]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
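    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the softmax of a score vector and the worked values for Input 1’s scores can be written out as follows. The symbol $z$ for the score vector is notation introduced here, not taken from the article:\n",
        "\n",
        "$$\\mathrm{softmax}(z)_i = \\frac{e^{z_i}}{\\sum_j e^{z_j}}$$\n",
        "\n",
        "$$\\mathrm{softmax}([2, 4, 4]) = \\frac{[e^{2},\\ e^{4},\\ e^{4}]}{e^{2} + 2e^{4}} \\approx [0.063,\\ 0.468,\\ 0.468]$$\n",
        "\n",
        "These values match the first row printed by the code above before it is rounded for readability."
      ]
    },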
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iBe71nseDBhb"
      },
      "source": [
        "### Step 6: Multiply scores with values\n",
        "![Step 6 illustration: multiply scores with values](https://miro.medium.com/max/1973/1*9cTaJGgXPbiJ4AOCc6QHyA.gif)\n",
        "\n",
        "The softmaxed attention scores for each input (blue) are multiplied with the corresponding **value** (purple). This results in 3 alignment vectors (yellow). In this tutorial, we’ll refer to them as **weighted values**.\n",
        "```\n",
        "1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]\n",
        "2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]\n",
        "3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tNnx-Fx5DFDi",
        "outputId": "abc7a8ec-f964-483a-9bfb-2848f0e8e592",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 212
        }
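      },
      "source": [
        "# values[:, None] has shape (3, 1, 3); attn_scores_softmax.T[:, :, None] has shape (3, 3, 1).\n",
        "# Broadcasting gives weighted_values[j, i] = attn_scores_softmax[i, j] * values[j].\n",
        "weighted_values = values[:, None] * attn_scores_softmax.T[:, :, None]\n",
        "print(weighted_values)"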
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[[0.0000, 0.0000, 0.0000],\n",
            "         [0.0000, 0.0000, 0.0000],\n",
            "         [0.0000, 0.0000, 0.0000]],\n",
            "\n",
            "        [[1.0000, 4.0000, 0.0000],\n",
            "         [2.0000, 8.0000, 0.0000],\n",
            "         [1.8000, 7.2000, 0.0000]],\n",
            "\n",
            "        [[1.0000, 3.0000, 1.5000],\n",
            "         [0.0000, 0.0000, 0.0000],\n",
            "         [0.2000, 0.6000, 0.3000]]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gU6w0U9ADQIc"
      },
      "source": [
        "### Step 7: Sum weighted values\n",
        "![Step 7 illustration: sum weighted values](https://miro.medium.com/max/1973/1*1je5TwhVAwwnIeDFvww3ew.gif)\n",
        "\n",
        "Take all the **weighted values** (yellow) and sum them element-wise:\n",
        "\n",
        "```\n",
        "  [0.0, 0.0, 0.0]\n",
        "+ [1.0, 4.0, 0.0]\n",
        "+ [1.0, 3.0, 1.5]\n",
        "-----------------\n",
        "= [2.0, 7.0, 1.5]\n",
        "```\n",
        "\n",
        "The resulting vector `[2.0, 7.0, 1.5]` (dark green) **is Output 1**, which is based on the **query representation from Input 1** interacting with all keys, including its own.\n"
      ]
    },
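    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Steps 6 and 7 together amount to a single matrix product: each output is the attention-weighted sum of the value vectors. A small sketch, using the rounded attention weights and the values from this tutorial (the names `attn` and `outputs_matmul` are introduced here for illustration):\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "# Rounded softmax weights from Step 5 (row i holds the weights used by query i)\n",
        "attn = torch.tensor([[0.0, 0.5, 0.5],\n",
        "                     [0.0, 1.0, 0.0],\n",
        "                     [0.0, 0.9, 0.1]])\n",
        "# Value vectors from Step 3\n",
        "values = torch.tensor([[1., 2., 3.],\n",
        "                       [2., 8., 0.],\n",
        "                       [2., 6., 3.]])\n",
        "\n",
        "# Row i of the result is sum_j attn[i, j] * values[j]\n",
        "outputs_matmul = attn @ values\n",
        "print(outputs_matmul)\n",
        "```\n",
        "\n",
        "The first row is `[2.0, 7.0, 1.5]`, the same Output 1 as derived by hand above."
      ]
    },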
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P3yNYDUEgAos"
      },
      "source": [
        "### Step 8: Repeat for Input 2 & Input 3\n",
        "![Step 8 illustration: repeat for Input 2 and Input 3](https://miro.medium.com/max/1973/1*G8thyDVqeD8WHim_QzjvFg.gif)\n",
        "\n",
        "Now that we have Output 1, we repeat Steps 4 to 7 for Input 2 and Input 3. The code below sums the weighted values for all three inputs at once.\n",
        "\n",
        "Note: *The dimension of **query** and **key** must always be the same because of the dot product score function. However, the dimension of **value** may be different from **query** and **key**. The resulting output will consequently follow the dimension of **value**.*"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R6excNSUDRRj",
        "outputId": "e5161fbe-05a5-41d2-da1e-5951ce8b1674",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        }
      },
      "source": [
        "outputs = weighted_values.sum(dim=0)\n",
        "print(outputs)\n",
        "\n",
        "# tensor([[2.0000, 7.0000, 1.5000],  # Output 1\n",
        "#         [2.0000, 8.0000, 0.0000],  # Output 2\n",
        "#         [2.0000, 7.8000, 0.3000]]) # Output 3"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor([[2.0000, 7.0000, 1.5000],\n",
            "        [2.0000, 8.0000, 0.0000],\n",
            "        [2.0000, 7.8000, 0.3000]])\n"
          ],
          "name": "stdout"
        }
      ]
    },
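    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Putting Steps 3 to 7 together, the whole computation fits in a small function. This is a sketch rather than a reference implementation: the function name `self_attention` is hypothetical, it uses the unscaled dot product from this tutorial, and it applies the exact softmax instead of the rounded weights, so its outputs differ slightly from the ones printed above.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch.nn.functional import softmax\n",
        "\n",
        "\n",
        "def self_attention(x, w_query, w_key, w_value):\n",
        "    # Step 3: derive the query, key and value representations\n",
        "    querys = x @ w_query\n",
        "    keys = x @ w_key\n",
        "    values = x @ w_value\n",
        "    # Steps 4 and 5: unscaled dot product scores, then softmax over the keys\n",
        "    attn = softmax(querys @ keys.T, dim=-1)\n",
        "    # Steps 6 and 7: attention-weighted sum of the values\n",
        "    return attn @ values\n",
        "\n",
        "\n",
        "# Inputs and weights copied from this tutorial so the block runs on its own\n",
        "x = torch.tensor([[1., 0., 1., 0.], [0., 2., 0., 2.], [1., 1., 1., 1.]])\n",
        "w_key = torch.tensor([[0., 0., 1.], [1., 1., 0.], [0., 1., 0.], [1., 1., 0.]])\n",
        "w_query = torch.tensor([[1., 0., 1.], [1., 0., 0.], [0., 0., 1.], [0., 1., 1.]])\n",
        "w_value = torch.tensor([[0., 2., 0.], [0., 3., 0.], [1., 0., 3.], [1., 1., 0.]])\n",
        "\n",
        "outputs_exact = self_attention(x, w_query, w_key, w_value)\n",
        "print(outputs_exact)\n",
        "```\n",
        "\n",
        "With the exact softmax, Output 1 comes out at roughly `[1.94, 6.68, 1.60]` instead of the rounded `[2.0, 7.0, 1.5]`."
      ]
    },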
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oavQirdbhAK7"
      },
      "source": [
        "### Bonus: TensorFlow 2 implementation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "575q0u_ahP-6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "867a4e88-2223-41e4-ccd5-dbc47f580c83"
      },
      "source": [
        "%tensorflow_version 2.x\n",
        "import tensorflow as tf"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "TensorFlow 2.x selected.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0vjwwEKMhqmZ",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 88
        },
        "outputId": "56e5ed58-e100-434d-a8b2-00325bfc0d40"
      },
      "source": [
        "x = [\n",
        "  [1, 0, 1, 0], # Input 1\n",
        "  [0, 2, 0, 2], # Input 2\n",
        "  [1, 1, 1, 1]  # Input 3\n",
        " ]\n",
        "\n",
        "x = tf.convert_to_tensor(x, dtype=tf.float32)\n",
        "print(x)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tf.Tensor(\n",
            "[[1. 0. 1. 0.]\n",
            " [0. 2. 0. 2.]\n",
            " [1. 1. 1. 1.]], shape=(3, 4), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TN-pri7rhwJ-",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 337
        },
        "outputId": "aa8b1395-80a3-41e1-b544-beb06ce65a96"
      },
      "source": [
        "w_key = [\n",
        "  [0, 0, 1],\n",
        "  [1, 1, 0],\n",
        "  [0, 1, 0],\n",
        "  [1, 1, 0]\n",
        "]\n",
        "w_query = [\n",
        "  [1, 0, 1],\n",
        "  [1, 0, 0],\n",
        "  [0, 0, 1],\n",
        "  [0, 1, 1]\n",
        "]\n",
        "w_value = [\n",
        "  [0, 2, 0],\n",
        "  [0, 3, 0],\n",
        "  [1, 0, 3],\n",
        "  [1, 1, 0]\n",
        "]\n",
        "w_key = tf.convert_to_tensor(w_key, dtype=tf.float32)\n",
        "w_query = tf.convert_to_tensor(w_query, dtype=tf.float32)\n",
        "w_value = tf.convert_to_tensor(w_value, dtype=tf.float32)\n",
        "print(\"Weights for key: \\n\", w_key)\n",
        "print(\"Weights for query: \\n\", w_query)\n",
        "print(\"Weights for value: \\n\", w_value)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Weights for key: \n",
            " tf.Tensor(\n",
            "[[0. 0. 1.]\n",
            " [1. 1. 0.]\n",
            " [0. 1. 0.]\n",
            " [1. 1. 0.]], shape=(4, 3), dtype=float32)\n",
            "Weights for query: \n",
            " tf.Tensor(\n",
            "[[1. 0. 1.]\n",
            " [1. 0. 0.]\n",
            " [0. 0. 1.]\n",
            " [0. 1. 1.]], shape=(4, 3), dtype=float32)\n",
            "Weights for value: \n",
            " tf.Tensor(\n",
            "[[0. 2. 0.]\n",
            " [0. 3. 0.]\n",
            " [1. 0. 3.]\n",
            " [1. 1. 0.]], shape=(4, 3), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jp2DP46Sh19r",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 230
        },
        "outputId": "5c1befaf-e096-454c-8402-885f049752e0"
      },
      "source": [
        "keys = tf.matmul(x, w_key)\n",
        "querys = tf.matmul(x, w_query)\n",
        "values = tf.matmul(x, w_value)\n",
        "print(keys)\n",
        "print(querys)\n",
        "print(values)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tf.Tensor(\n",
            "[[0. 1. 1.]\n",
            " [4. 4. 0.]\n",
            " [2. 3. 1.]], shape=(3, 3), dtype=float32)\n",
            "tf.Tensor(\n",
            "[[1. 0. 2.]\n",
            " [2. 2. 2.]\n",
            " [2. 1. 3.]], shape=(3, 3), dtype=float32)\n",
            "tf.Tensor(\n",
            "[[1. 2. 3.]\n",
            " [2. 8. 0.]\n",
            " [2. 6. 3.]], shape=(3, 3), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tLJDo_bFigkm",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 88
        },
        "outputId": "b5d8e02d-9531-49c8-a587-7a6e0b6f884d"
      },
      "source": [
        "attn_scores = tf.matmul(querys, keys, transpose_b=True)\n",
        "print(attn_scores)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tf.Tensor(\n",
            "[[ 2.  4.  4.]\n",
            " [ 4. 16. 12.]\n",
            " [ 4. 12. 10.]], shape=(3, 3), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8QY858MEiibV",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 159
        },
        "outputId": "2e84f48b-a4ed-4116-8655-21cbb9de8358"
      },
      "source": [
        "attn_scores_softmax = tf.nn.softmax(attn_scores, axis=-1)\n",
        "print(attn_scores_softmax)\n",
        "\n",
        "# For readability, approximate the above as follows\n",
        "attn_scores_softmax = [\n",
        "  [0.0, 0.5, 0.5],\n",
        "  [0.0, 1.0, 0.0],\n",
        "  [0.0, 0.9, 0.1]\n",
        "]\n",
        "attn_scores_softmax = tf.convert_to_tensor(attn_scores_softmax)\n",
        "print(attn_scores_softmax)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tf.Tensor(\n",
            "[[6.3378938e-02 4.6831051e-01 4.6831051e-01]\n",
            " [6.0336647e-06 9.8200780e-01 1.7986100e-02]\n",
            " [2.9538720e-04 8.8053685e-01 1.1916770e-01]], shape=(3, 3), dtype=float32)\n",
            "tf.Tensor(\n",
            "[[0.  0.5 0.5]\n",
            " [0.  1.  0. ]\n",
            " [0.  0.9 0.1]], shape=(3, 3), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TOJMfkFpi0KQ",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 230
        },
        "outputId": "8de18989-50d7-4534-cf5c-2711c66d17ce"
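      },
      "source": [
        "# values[:, None] has shape (3, 1, 3); the transposed weights with a trailing axis have shape (3, 3, 1).\n",
        "# Broadcasting gives weighted_values[j, i] = attn_scores_softmax[i, j] * values[j].\n",
        "weighted_values = values[:, None] * tf.transpose(attn_scores_softmax)[:, :, None]\n",
        "print(weighted_values)"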
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tf.Tensor(\n",
            "[[[0.  0.  0. ]\n",
            "  [0.  0.  0. ]\n",
            "  [0.  0.  0. ]]\n",
            "\n",
            " [[1.  4.  0. ]\n",
            "  [2.  8.  0. ]\n",
            "  [1.8 7.2 0. ]]\n",
            "\n",
            " [[1.  3.  1.5]\n",
            "  [0.  0.  0. ]\n",
            "  [0.2 0.6 0.3]]], shape=(3, 3, 3), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jan_cyy7i-s7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 88
        },
        "outputId": "09b1406f-3a08-47e2-8dee-d4d6334ef1de"
      },
      "source": [
        "outputs = tf.reduce_sum(weighted_values, axis=0)  # sum over the inputs (keys) to get one output per query\n",
        "print(outputs)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tf.Tensor(\n",
            "[[2.        7.        1.5      ]\n",
            " [2.        8.        0.       ]\n",
            " [2.        7.7999997 0.3      ]], shape=(3, 3), dtype=float32)\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}